import os
import yaml
import stat


def load_yaml():
    """Load the project's ``config.yaml``, creating a default one if missing.

    The file is expected in the parent directory of the directory that
    contains this module. If it does not exist, a default configuration with
    placeholder paths (``model_path``, ``offload_path``, ``dataset_path``) and
    ``ffn_path`` set to this module's directory is written first; on
    non-Windows systems the new file's mode is then set to ``0o644``.

    Returns:
        The parsed YAML document. An empty (or comment-only) file yields an
        empty dict rather than ``None``.

    Raises:
        OSError: If the config file cannot be created or read.
        yaml.YAMLError: If the file is not valid YAML.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_dir = os.path.dirname(current_dir)
    yaml_path = os.path.join(project_dir, 'config.yaml')

    if not os.path.exists(yaml_path):
        print(f"{yaml_path} 不存在，正在创建默认配置文件...")
        default_config = {
            'model_path': 'path/to/model',
            'offload_path': 'path/to/offload',
            'dataset_path': 'path/to/dataset',
            'ffn_path': current_dir,
        }

        # Write explicitly as UTF-8 so the file does not depend on the
        # platform's locale encoding. safe_dump keeps the output loadable by
        # safe_load, and allow_unicode keeps non-ASCII paths human-readable
        # instead of \u-escaped.
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(default_config, f, allow_unicode=True)

        print(f"已创建默认配置文件：{yaml_path}")

        # On non-Windows systems set the file mode to 0o644 (rw-r--r--).
        if os.name != 'nt':  # 'nt' means Windows
            os.chmod(
                yaml_path,
                stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH,
            )

    # Read as UTF-8 for the same reason: a config containing non-ASCII
    # characters must parse identically on every platform.
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; hand callers an empty
    # mapping so key lookups fail with KeyError instead of TypeError.
    if config is None:
        return {}
    return config


if __name__ == '__main__':
    # Manual smoke check: load the configuration and print it.
    print(load_yaml())
